#' Map a sweep value onto its rank on the colour scale.
#'
#' The recognised values 1, 5, 10, 20, 50 and 100 correspond, in that order,
#' to the ranks 1 through 6.
#'
#' @param x A single value to look up (compared with `==`).
#' @return The numeric rank (1-6) of a recognised value, or the string
#'   `"error"` when `x` equals none of the recognised values.
to_colorscale <- function(x) {
  known_values <- c(1, 5, 10, 20, 50, 100)
  for (rank in seq_along(known_values)) {
    if (x == known_values[rank]) {
      # Ranks are returned as doubles, not integers.
      return(as.numeric(rank))
    }
  }
  "error"
}

#' Map a sweep value onto its rank on the colour scale (1/3/5/10/50 variant).
#'
#' The recognised values 1, 3, 5, 10 and 50 correspond, in that order, to the
#' ranks 1 through 5.
#'
#' @param x A single value to look up (compared with `==`).
#' @return The numeric rank (1-5) of a recognised value, or the string
#'   `"error"` when `x` equals none of the recognised values.
to_colorscale2 <- function(x) {
  known_values <- c(1, 3, 5, 10, 50)
  for (rank in seq_along(known_values)) {
    if (x == known_values[rank]) {
      # Ranks are returned as doubles, not integers.
      return(as.numeric(rank))
    }
  }
  "error"
}

#' Map a sweep value onto its rank on the colour scale (1/5/10/50/100 variant).
#'
#' The recognised values 1, 5, 10, 50 and 100 correspond, in that order, to
#' the ranks 1 through 5.
#'
#' @param x A single value to look up (compared with `==`).
#' @return The numeric rank (1-5) of a recognised value, or the string
#'   `"error"` when `x` equals none of the recognised values.
to_colorscale3 <- function(x) {
  known_values <- c(1, 5, 10, 50, 100)
  for (rank in seq_along(known_values)) {
    if (x == known_values[rank]) {
      # Ranks are returned as doubles, not integers.
      return(as.numeric(rank))
    }
  }
  "error"
}

#' Collect fixation and tunnelling summaries across a parameter sweep.
#'
#' For every value of `first_sweep` the function reads two headerless CSV
#' files, `iftype2.oevo` and `type_1_tunnel.oevo`, from each of these
#' directories: `<filepath><second>/<first>/` (one per value of
#' `second_sweep`), `<filepath>heritable/<first>/` and
#' `<filepath>simple/<first>/`.
#'
#' Each pair of files is reduced to
#' * `fix`: the sum of the second column of `iftype2.oevo` divided by
#'   `num_runs`;
#' * `tunnel`: the sum of the second column of `type_1_tunnel.oevo` divided by
#'   `num_runs`, then divided by `fix`. The ratio is not finite when `fix`
#'   is 0.
#'
#' Two standard-error columns are added to every table:
#' `fix_SEM = sqrt(fix * (1 - fix) / num_runs)` and
#' `tunnel_SEM = sqrt(tunnel * (1 - tunnel) / (fix * num_runs))`.
#'
#' `read_csv` must be available to the caller's session (presumably readr's;
#' this function does not load it).
#'
#' @param filepath Prefix pasted directly in front of the directory names, so
#'   it has to end with a path separator already.
#' @param first_sweep Values of the first swept parameter (innermost
#'   directory level).
#' @param second_sweep Values of the second swept parameter (outer directory
#'   level, alongside `heritable` and `simple`).
#' @param num_runs Number of runs used to turn the summed counts into
#'   fractions.
#' @return A list with the data frames `fix_dat` (columns `first`, `second`,
#'   `fix`, `tunnel`, `fix_SEM`, `tunnel_SEM`), `simple_dat` and `her_dat`
#'   (same columns without `second`).
extract_data <- function(filepath, first_sweep, second_sweep, num_runs) {
  # Read the two result files of one run directory and reduce them to the
  # fixation fraction and the tunnelling fraction relative to it.
  read_rates <- function(condition, first_value) {
    run_dir <- paste0(filepath, condition, "/", first_value)
    fix_counts <- read_csv(paste0(run_dir, "/iftype2.oevo"), col_names = FALSE)
    fix <- sum(fix_counts$X2) / num_runs
    tunnel_counts <- read_csv(
      paste0(run_dir, "/type_1_tunnel.oevo"),
      col_names = FALSE
    )
    tunnel <- (sum(tunnel_counts$X2) / num_runs) / fix
    list(fix = fix, tunnel = tunnel)
  }

  # Bind all collected one-row frames in a single call. The zero-row template
  # keeps the column layout when a sweep is empty.
  stack_rows <- function(template, rows) {
    do.call(rbind, c(list(template), rows))
  }

  # Append the two standard-error columns to a summary table.
  add_sem <- function(dat) {
    dat["fix_SEM"] <- sqrt(dat$fix * (1 - dat$fix) / num_runs)
    dat["tunnel_SEM"] <- sqrt(
      dat$tunnel * (1 - dat$tunnel) / (dat$fix * num_runs)
    )
    dat
  }

  fix_rows <- list()
  her_rows <- list()
  simple_rows <- list()
  for (first_value in first_sweep) {
    for (second_value in second_sweep) {
      rates <- read_rates(second_value, first_value)
      fix_rows[[length(fix_rows) + 1L]] <- data.frame(
        first = first_value,
        second = second_value,
        fix = rates$fix,
        tunnel = rates$tunnel
      )
    }

    rates <- read_rates("heritable", first_value)
    her_rows[[length(her_rows) + 1L]] <- data.frame(
      first = first_value,
      fix = rates$fix,
      tunnel = rates$tunnel
    )

    rates <- read_rates("simple", first_value)
    simple_rows[[length(simple_rows) + 1L]] <- data.frame(
      first = first_value,
      fix = rates$fix,
      tunnel = rates$tunnel
    )
  }

  sweep_template <- data.frame(
    first = numeric(0),
    second = numeric(0),
    fix = numeric(0),
    tunnel = numeric(0)
  )
  single_template <- data.frame(
    first = numeric(0),
    fix = numeric(0),
    tunnel = numeric(0)
  )

  list(
    fix_dat = add_sem(stack_rows(sweep_template, fix_rows)),
    simple_dat = add_sem(stack_rows(single_template, simple_rows)),
    her_dat = add_sem(stack_rows(single_template, her_rows))
  )
}

#' Summarise bootstrap estimates of the normalised variance per setting.
#'
#' For every combination of an input-variance label `k` (outer loop) and a
#' lifetime `i` (inner loop) the function reads the files
#' `<filepath>lognorm_ss_dist/<i>/<k>/fit_sim_<j>type_0.oevo` for
#' `j = 1, ..., num_bootstrap`, skipping the first line of each file. From the
#' first remaining row it takes columns 2 to 101, divides them by their mean,
#' and records the variance of the result and its base-10 logarithm.
#'
#' Per combination, the replicate variances (and their logs) are summarised
#' by their mean and by the 2.5% and 97.5% quantiles of their empirical
#' distribution.
#'
#' `read_csv` must be available to the caller's session (presumably readr's;
#' this function does not load it).
#'
#' @param filepath Prefix pasted directly in front of `"lognorm_ss_dist/"`, so
#'   it has to end with a path separator already.
#' @param num_bootstrap Number of replicate files per combination.
#' @param input_vars Values of the `<k>` directory level; reported in the
#'   `input_var` column.
#' @param lifetimes Values of the `<i>` directory level; reported in the
#'   `lifetimes` column.
#' @return A data frame with one row per combination and the columns
#'   `mean_var`, `var_lower`, `var_upper`, `log_var`, `log_lower`,
#'   `log_upper`, `input_var` and `lifetimes`.
extract_vars <- function(filepath, num_bootstrap = 1000, input_vars = 2:6,
                         lifetimes = c(1, 5, 10, 50, 100)) {
  stopifnot(
    is.numeric(num_bootstrap),
    length(num_bootstrap) == 1,
    num_bootstrap >= 1
  )

  n_settings <- length(input_vars) * length(lifetimes)
  mean_var <- numeric(n_settings)
  var_lower <- numeric(n_settings)
  var_upper <- numeric(n_settings)
  log_var <- numeric(n_settings)
  log_lower <- numeric(n_settings)
  log_upper <- numeric(n_settings)

  setting <- 0L
  for (k in input_vars) {
    for (i in lifetimes) {
      setting <- setting + 1L

      est_var <- numeric(num_bootstrap)
      for (j in seq_len(num_bootstrap)) {
        filename <- paste0(
          filepath, "lognorm_ss_dist/", i, "/", k,
          "/fit_sim_", j, "type_0.oevo"
        )
        temp <- read_csv(filename, skip = 1, col_names = FALSE)
        # Flatten the 100 values of the first row into a plain vector.
        values <- as.vector(unlist(t(temp[1, 2:101])))
        est_var[j] <- var(values / mean(values))
      }
      est_log_var <- log10(est_var)

      # Quantiles are taken from the empirical CDF of the replicates.
      var_cdf <- ecdf(est_var)
      log_var_cdf <- ecdf(est_log_var)
      mean_var[setting] <- mean(est_var)
      var_lower[setting] <- quantile(var_cdf, 0.025)
      var_upper[setting] <- quantile(var_cdf, 0.975)
      log_var[setting] <- mean(est_log_var)
      log_lower[setting] <- quantile(log_var_cdf, 0.025)
      log_upper[setting] <- quantile(log_var_cdf, 0.975)
    }
  }

  data.frame(
    mean_var = mean_var,
    var_lower = var_lower,
    var_upper = var_upper,
    log_var = log_var,
    log_lower = log_lower,
    log_upper = log_upper,
    # Same ordering as the loops above: input variance outer, lifetime inner.
    input_var = rep(input_vars, each = length(lifetimes)),
    lifetimes = rep(lifetimes, times = length(input_vars))
  )
}

#' Summarise bootstrap estimates of the normalised variance (heritable runs).
#'
#' For every input-variance label `k` the function reads the files
#' `<filepath>heritable_lognorm_ss_dist/<k>/fit_sim_<j>type_0.oevo` for
#' `j = 1, ..., num_bootstrap`, skipping the first line of each file. From the
#' first remaining row it takes columns 2 to 101, divides them by their mean,
#' and records the variance of the result and its base-10 logarithm.
#'
#' Per label, the replicate variances (and their logs) are summarised by
#' their mean and by the 2.5% and 97.5% quantiles of their empirical
#' distribution.
#'
#' `read_csv` must be available to the caller's session (presumably readr's;
#' this function does not load it).
#'
#' @param filepath Prefix pasted directly in front of
#'   `"heritable_lognorm_ss_dist/"`, so it has to end with a path separator
#'   already.
#' @param num_bootstrap Number of replicate files per label.
#' @param input_vars Values of the `<k>` directory level; reported in the
#'   `input_var` column.
#' @return A data frame with one row per label and the columns `mean_var`,
#'   `var_lower`, `var_upper`, `log_var`, `log_lower`, `log_upper` and
#'   `input_var`.
extract_vars_heritable <- function(filepath, num_bootstrap = 1000,
                                   input_vars = 2:6) {
  stopifnot(
    is.numeric(num_bootstrap),
    length(num_bootstrap) == 1,
    num_bootstrap >= 1
  )

  n_settings <- length(input_vars)
  mean_var <- numeric(n_settings)
  var_lower <- numeric(n_settings)
  var_upper <- numeric(n_settings)
  log_var <- numeric(n_settings)
  log_lower <- numeric(n_settings)
  log_upper <- numeric(n_settings)

  setting <- 0L
  for (k in input_vars) {
    setting <- setting + 1L

    est_var <- numeric(num_bootstrap)
    for (j in seq_len(num_bootstrap)) {
      filename <- paste0(
        filepath, "heritable_lognorm_ss_dist/", k,
        "/fit_sim_", j, "type_0.oevo"
      )
      temp <- read_csv(filename, skip = 1, col_names = FALSE)
      # Flatten the 100 values of the first row into a plain vector.
      values <- as.vector(unlist(t(temp[1, 2:101])))
      est_var[j] <- var(values / mean(values))
    }
    est_log_var <- log10(est_var)

    # Quantiles are taken from the empirical CDF of the replicates.
    var_cdf <- ecdf(est_var)
    log_var_cdf <- ecdf(est_log_var)
    mean_var[setting] <- mean(est_var)
    var_lower[setting] <- quantile(var_cdf, 0.025)
    var_upper[setting] <- quantile(var_cdf, 0.975)
    log_var[setting] <- mean(est_log_var)
    log_lower[setting] <- quantile(log_var_cdf, 0.025)
    log_upper[setting] <- quantile(log_var_cdf, 0.975)
  }

  data.frame(
    mean_var = mean_var,
    var_lower = var_lower,
    var_upper = var_upper,
    log_var = log_var,
    log_lower = log_lower,
    log_upper = log_upper,
    input_var = input_vars
  )
}

#' Build the fixation and tunnelling line plots for a parameter sweep.
#'
#' Both plots share the same structure: one line per value of `second` in
#' `fix_dat`, coloured by `to_colorscale(second)` on a three-colour gradient,
#' plus one grey line for `simple_dat` and one violet line for `her_dat`.
#' Every line carries error bars of +/- 1.96 times the matching `*_SEM`
#' column. Legend, axis titles, grid lines and panel border are removed.
#'
#' @param fix_dat Data frame with columns `first`, `second`, `fix`, `tunnel`,
#'   `fix_SEM` and `tunnel_SEM`.
#' @param simple_dat Data frame with columns `first`, `fix`, `tunnel`,
#'   `fix_SEM` and `tunnel_SEM`.
#' @param her_dat Data frame with the same columns as `simple_dat`.
#' @return A list with the ggplot objects `fix` and `tunnel`.
plot_fix_tunnel <- function(fix_dat, simple_dat, her_dat) {
  simple_colour <- "#707070"
  heritable_colour <- "#4400cc"

  # Tweaks applied on top of theme_bw() in both panels.
  bare_theme <- theme(
    legend.position = "none",
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    text = element_text(size = 20),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    axis.line = element_line(colour = "black"),
    panel.border = element_blank()
  )

  fix_panel <- ggplot(data = fix_dat) +
    geom_line(
      aes(
        x = first, y = fix,
        color = sapply(second, FUN = to_colorscale), group = second
      )
    ) +
    theme_bw() +
    geom_errorbar(
      aes(
        x = first,
        ymin = fix - fix_SEM * 1.96, ymax = fix + fix_SEM * 1.96,
        color = sapply(second, FUN = to_colorscale)
      ),
      width = 0
    ) +
    geom_line(
      data = simple_dat, aes(x = first, y = fix),
      color = simple_colour, size = 1
    ) +
    geom_errorbar(
      data = simple_dat,
      aes(x = first, ymin = fix - fix_SEM * 1.96, ymax = fix + fix_SEM * 1.96),
      color = simple_colour, width = 0
    ) +
    geom_line(
      data = her_dat, aes(x = first, y = fix),
      color = heritable_colour, size = 1
    ) +
    geom_errorbar(
      data = her_dat,
      aes(x = first, ymin = fix - fix_SEM * 1.96, ymax = fix + fix_SEM * 1.96),
      color = heritable_colour, width = 0
    ) +
    scale_color_gradient2(
      low = "#cccc00", mid = "#ff3399", high = "#9900ff", midpoint = 4
    ) +
    bare_theme

  tunnel_panel <- ggplot(data = fix_dat) +
    geom_line(
      aes(
        x = first, y = tunnel,
        color = sapply(second, FUN = to_colorscale), group = second
      )
    ) +
    theme_bw() +
    geom_errorbar(
      aes(
        x = first,
        ymin = tunnel - tunnel_SEM * 1.96, ymax = tunnel + tunnel_SEM * 1.96,
        color = sapply(second, FUN = to_colorscale)
      ),
      width = 0
    ) +
    geom_line(
      data = simple_dat, aes(x = first, y = tunnel),
      color = simple_colour, size = 1
    ) +
    geom_errorbar(
      data = simple_dat,
      aes(
        x = first,
        ymin = tunnel - tunnel_SEM * 1.96, ymax = tunnel + tunnel_SEM * 1.96
      ),
      color = simple_colour, width = 0
    ) +
    geom_line(
      data = her_dat, aes(x = first, y = tunnel),
      color = heritable_colour, size = 1
    ) +
    geom_errorbar(
      data = her_dat,
      aes(
        x = first,
        ymin = tunnel - tunnel_SEM * 1.96, ymax = tunnel + tunnel_SEM * 1.96
      ),
      color = heritable_colour, width = 0
    ) +
    scale_color_gradient2(
      low = "#cccc00", mid = "#ff3399", high = "#9900ff", midpoint = 4
    ) +
    bare_theme

  list(fix = fix_panel, tunnel = tunnel_panel)
}

#' Plot density curves of `fit`, one per lifetime, plus the heritable case.
#'
#' Rows whose `lifetimes` value is 1, 3, 5, 10 or 50 are drawn as one density
#' line per lifetime, coloured by `to_colorscale2(lifetimes)` on a
#' three-colour gradient. Rows whose `lifetimes` equals `"heritable"` are
#' drawn as a single violet density line. Legend, axis titles, grid lines and
#' panel border are removed.
#'
#' @param dist_data Data frame with the columns `fit` and `lifetimes`.
#' @return A ggplot object.
plot_dists <- function(dist_data) {
  swept_rows <- dist_data[dist_data$lifetimes %in% c(1, 3, 5, 10, 50), ]
  heritable_rows <- dist_data[dist_data$lifetimes == "heritable", ]

  ggplot(data = swept_rows) +
    geom_line(
      aes(
        x = fit, group = lifetimes,
        color = sapply(lifetimes, FUN = to_colorscale2)
      ),
      stat = "density"
    ) +
    theme_bw() +
    scale_color_gradient2(
      low = "#cccc00", mid = "#ff3399", high = "#9900ff", midpoint = 4
    ) +
    geom_line(
      data = heritable_rows, aes(x = fit),
      stat = "density", color = "#4400cc"
    ) +
    theme(
      legend.position = "none",
      axis.title.x = element_blank(),
      axis.title.y = element_blank(),
      text = element_text(size = 20),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      axis.line = element_line(colour = "black"),
      panel.border = element_blank()
    )
}

#' Plot density curves of `fit`, one per input variance.
#'
#' Each value of `input_variances` gets its own density line, coloured by the
#' negated input variance on a three-colour gradient centred at -4. Legend,
#' axis titles, grid lines and panel border are removed.
#'
#' @param dist_data Data frame with the columns `fit` and `input_variances`.
#' @return A ggplot object.
plot_dists_var <- function(dist_data) {
  ggplot(data = dist_data) +
    geom_line(
      aes(x = fit, group = input_variances, color = -input_variances),
      stat = "density"
    ) +
    theme_bw() +
    scale_color_gradient2(
      low = "#cccc00", mid = "#00cc00", high = "#0066ff", midpoint = -4
    ) +
    theme(
      legend.position = "none",
      axis.title.x = element_blank(),
      axis.title.y = element_blank(),
      text = element_text(size = 20),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      axis.line = element_line(colour = "black"),
      panel.border = element_blank()
    )
}
